{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import time\n",
    "import torch\n",
    "import pytorch_mask_rcnn as pmr\n",
    "    \n",
    "    \n",
    "def main(args):\n",
    "    \"\"\"Evaluate a saved Mask R-CNN checkpoint and print its results and timing.\n",
    "\n",
    "    Builds the \"val2017\" dataset, restores the model weights stored under the\n",
    "    checkpoint's \"model\" key, prints the checkpoint's \"eval_info\" entry, runs\n",
    "    ``pmr.evaluate`` and finally prints the wall-clock time and throughput.\n",
    "\n",
    "    Args:\n",
    "        args: Namespace-like object. ``dataset``, ``data_dir`` and ``ckpt_path``\n",
    "            are read here; ``use_cuda`` is consulted only when CUDA is available;\n",
    "            ``batch_size`` is optional and defaults to 1. The whole object is\n",
    "            also passed on to ``pmr.evaluate``.\n",
    "    \"\"\"\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() and args.use_cuda else \"cpu\")\n",
    "    if device.type == \"cuda\":\n",
    "        pmr.get_gpu_prop(show=True)\n",
    "    print(\"\\ndevice: {}\".format(device))\n",
    "\n",
    "    # Per the original note, train=True has to be set for evaluation even though\n",
    "    # this is the \"val2017\" split.\n",
    "    d_test = pmr.datasets(args.dataset, args.data_dir, \"val2017\", train=True)\n",
    "\n",
    "    print(args)\n",
    "    # One more output than the dataset has classes -- presumably the extra slot is\n",
    "    # the background class (confirm against pmr.maskrcnn_resnet50).\n",
    "    num_classes = len(d_test.dataset.classes) + 1\n",
    "    model = pmr.maskrcnn_resnet50(False, num_classes).to(device)\n",
    "\n",
    "    # NOTE: torch.load unpickles the file, so only load checkpoints you trust.\n",
    "    checkpoint = torch.load(args.ckpt_path, map_location=device)\n",
    "    model.load_state_dict(checkpoint[\"model\"])\n",
    "    print(checkpoint[\"eval_info\"])\n",
    "    # Drop the loaded checkpoint before evaluating and release cached GPU memory.\n",
    "    del checkpoint\n",
    "    torch.cuda.empty_cache()\n",
    "\n",
    "    # `batch_size` is not guaranteed to exist on `args`; reading it directly would\n",
    "    # raise AttributeError only after the whole evaluation had already finished.\n",
    "    # NOTE(review): the fallback of 1 assumes one image per evaluation iteration --\n",
    "    # confirm against pmr.evaluate.\n",
    "    batch_size = getattr(args, \"batch_size\", 1)\n",
    "\n",
    "    print(\"evaluating only...\")\n",
    "    start = time.time()\n",
    "    eval_output, iter_eval = pmr.evaluate(model, d_test, device, args)\n",
    "    elapsed = time.time() - start\n",
    "    print(eval_output)\n",
    "    print(\"\\ntotal time of this evaluation: {:.2f} s, speed: {:.2f} FPS\".format(elapsed, batch_size / iter_eval))\n",
    "    \n",
    "    \n",
    "if __name__ == \"__main__\":\n",
    "    import argparse\n",
    "\n",
    "    # Options that go through argparse; only their defaults are used below.\n",
    "    parser = argparse.ArgumentParser()\n",
    "    parser.add_argument(\"--dataset\", default=\"coco\")\n",
    "    parser.add_argument(\"--data-dir\")\n",
    "    parser.add_argument(\"--iters\", type=int, default=-1)\n",
    "\n",
    "    # An explicit empty argument list is parsed (for Jupyter Notebook).\n",
    "    args = parser.parse_args([])\n",
    "\n",
    "    # Settings written straight onto the parsed namespace.\n",
    "    vars(args).update(\n",
    "        use_cuda=True,\n",
    "        data_dir=\"/data/coco2017\",\n",
    "        ckpt_path=\"/ckpt/maskrcnn_coco.pth\",\n",
    "    )\n",
    "    # The results file is placed in the same directory as the checkpoint.\n",
    "    args.results = os.path.join(os.path.split(args.ckpt_path)[0], \"results.pth\")\n",
    "\n",
    "    main(args)\n",
    "    \n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
